################################
## Script: 08_stm.R
## Purpose: This code runs the STM model on the bioweapons docs
## and creates the topic clusters. 
## Data In:
## 1) full version of articles with full text
## we can't share this data but include it as reference:
## /scratch/olympus/projects/russia_ukraine_war/bioweapons_new/bioweapons_casestudy_5_20_2024_sbert_embeddings.json
## 2) public version of that data:
## data/bioweapons_casestudy_5_20_2024_sbert_embeddings.json
## Intermediary files which are in replication 
## data folder:
## 3) DFM for STM replication:
## data/dfm_stm.rds
## 4) Model diagnostics for searchK over
## topic models
## data/stm_models.rds
## 5) Topic model with 30 topics:
## data/stm_fit_30.rds
## Data Out:
## 1) Figures:
## figures/stm_topic_choose.pdf
## figures/topic_cluster_distribution.pdf
## figures/topic_bin_threshold.pdf
## figures/topic_bin_threshold.pdf
## 2) Master coding file for precision
## analysis STM:
## data/precision_stm_test_updated.rds
## 3) Individual coding files:
## a) data/precision_stm_test_ra1_updated_10_9_2024.csv
## b) data/precision_stm_test_ra2_updated_10_9_2024.csv
## 4) List of document pairs
## in topic cluster from full bioweapons
## dataset:
## data/stm_pairs.rds
  
## Notes:
## The first half of the code we can't
## share replication data for because it requires
## article full text. We include the code as a reference
## and then create a derivative dfm for replication purposes
## which we do include in the replication package. 


library(quanteda)
library(tidyverse)
library(stm)
library(viridis)
library(igraph)

## Restricted full-text articles (cannot be shared; path kept as reference)
total_art <- 
  file("/scratch/olympus/projects/russia_ukraine_war/bioweapons_new/bioweapons_casestudy_5_20_2024_sbert_embeddings.json") %>%
  jsonlite::stream_in()

## Public release of the same articles (shareable)
total_art_public <- 
  file("data/bioweapons_casestudy_5_20_2024_sbert_embeddings.json") %>%
  jsonlite::stream_in()

###################################
### Pre-Processing ################
## with the exception of creating the 
## source type variable, this code can't be run

## add source type variable

## classify_source_type(): map outlet identifiers to a source category.
##   source_name: vector of outlet identifiers (e.g. "infowars", "rt").
##   returns: character vector, same length as source_name, with values
##            "low quality", "most popular US", "russian state media",
##            or "ukranian" for any outlet not in one of the three lists
##            (the spelling of the last value is kept as-is because the
##            saved replication data uses it).
## Both the restricted and the public article tables use this one
## definition so that the two classifications cannot drift apart.
classify_source_type <- function(source_name) {
  low_quality <- c("100_percent_fed_up",
                   "bipartisan_report",
                   "clash_daily",
                   "crooks_and_liars",
                   "daily_caller",
                   "ijr",
                   "infowars",
                   "judicial_watch",
                   "natural_news",
                   "occupy_democrats",
                   "palmer_report",
                   "stillness_in_the_storm",
                   "the_federalist_papers",
                   "the_gateway_pundit",
                   "the_mind_unleashed",
                   "the_political_insider",
                   "zerohedge")

  most_popular_us <- c("abc_news",
                       "business_insider",
                       "cbs_news",
                       "cnbc",
                       "cnn",
                       "fox_news",
                       "huffpost",
                       "latimes",
                       "msnbc",
                       "nbc_news",
                       "new_york_post",
                       "new_york_times",
                       "npr",
                       "pbs",
                       "politico",
                       "slate",
                       "star_tribune",
                       "the_hill",
                       "usa_today",
                       "wall_street_journal",
                       "washington_post",
                       "yahoo_news")

  russian_state_media <- c("pravda",
                           "pravda_ru",
                           "rt",
                           "rt_ru",
                           "sputnik",
                           "sputnik_cn",
                           "tass",
                           "tass_ru")

  ## %in% never yields NA, so a missing outlet name falls through to the
  ## final default category, exactly as any unlisted outlet does
  ifelse(source_name %in% low_quality,
         "low quality",
         ifelse(source_name %in% most_popular_us,
                "most popular US",
                ifelse(source_name %in% russian_state_media,
                       "russian state media",
                       "ukranian")))
}

## restricted (full text) articles
total_art$source_type <- classify_source_type(total_art$source_name)

## public articles
total_art_public$source_type <- 
  classify_source_type(total_art_public$source_name)

## Build the quanteda corpus: one document per article id, using the
## translated article text
corpus <- corpus(total_art,
                 text_field = "content_translated",
                 docid_field = "article_id")

## Tokenize (dropping punctuation, numbers and URLs), then remove English
## stopwords, lowercase, and stem -- in that order
tokens <- corpus %>%
  tokens(remove_punct = TRUE,
         remove_numbers = TRUE,
         remove_url = TRUE,
         verbose = TRUE) %>%
  tokens_select(pattern = stopwords("en"),
                selection = "remove") %>%
  tokens_tolower() %>%
  tokens_wordstem()

## Document-feature matrix, keeping only features that appear in at
## least 1% and at most 50% of documents
dfm <- tokens %>%
  dfm() %>%
  dfm_trim(min_docfreq = 0.01,
           max_docfreq = 0.50,
           docfreq_type = "prop",
           verbose = TRUE)

dim(dfm) # 3491 x 3304
## check for documents left with no features after trimming
table(apply(dfm, 1, function(doc_counts) {
  all(doc_counts == 0)
})) ## none removed

## convert to the list format expected by stm
dfm_stm <- quanteda::convert(dfm, to = "stm")

## keep only the metadata columns that can be shared
dfm_stm$meta <- dplyr::select(dfm_stm$meta,
                              final_date,
                              article_url,
                              language,
                              summary,
                              source_type)


#saveRDS(dfm_stm, "data/dfm_stm.rds")

##########################
### Run Models ##########
#########################

## can replicate from here
## with derivative stm data 

dfm_stm <- readRDS("data/dfm_stm.rds")

## Fit diagnostics for candidate topic counts 20, 30, ..., 120,
## with topic prevalence varying by source type
set.seed(97405)
fit <- searchK(documents = dfm_stm$documents,
               vocab = dfm_stm$vocab,
               K = seq(20, 120, by = 10),
               init.type = "Spectral",
               prevalence = ~ source_type,
               data = dfm_stm$meta,
               max.em.its = 100)



#saveRDS(fit, "data/stm_models.rds")

## load the stored searchK output
fit <- readRDS("data/stm_models.rds")

##########################
### Compare Model Fit ####
##########################

## searchK stores these diagnostics as list columns; flatten them
results <- fit$results
diag_cols <- c("exclus", "semcoh")
results[diag_cols] <- lapply(results[diag_cols], unlist)

## Exclusivity against semantic coherence for each candidate model
## (Figure A9 in the SI); points are colored by number of topics
ggplot(results, aes(x = semcoh, y = exclus, color = as.numeric(K))) +
  geom_point() +
  scale_color_viridis() +
  annotate(geom = "label",
           x = -53,
           y = 9.8,
           label = "30 topics balances\ntopic exclusivity\nand coherence") +
  annotate(geom = "curve",
           x = -53, y = 9.79,
           xend = -54, yend = 9.76,
           curvature = -.5,
           arrow = arrow(length = unit(.2, "cm"))) +
  labs(x = "Semantic Coherence",
       y = "Topic Exclusivity",
       color = "Number\nof Topics") +
  theme_bw()
## figures/stm_topic_choose.pdf


## Fit the chosen 30-topic model, with topic prevalence
## varying by source type

set.seed(97405)
fit <- stm(documents = dfm_stm$documents,
           vocab = dfm_stm$vocab,
           K = 30,
           prevalence = ~ source_type,
           data = dfm_stm$meta,
           init.type = "Spectral",
           max.em.its = 100)

#saveRDS(fit, "data/stm_fit_30.rds")

## Small numerical instability in the estimated theta matrix
## leads to larger differences in the downstream clustering,
## so the fitted stm object is shipped with the replication
## data and loaded here in place of the fresh fit

fit <- readRDS("data/stm_fit_30.rds")

#######################################
### Labeling Topics ###################
#######################################

## we labelled topics by reading
## documents most associated with 
## topic

## One row per topic; `description` is left empty (NA)
labels <- data.frame(
  topic = 1:30,
  label = c(
    "pathogens, diseases, chemical attacks, potential for Russian chemical attacks",
    "US admits to biological labs, victoria nuland, stillness in the storm",
    "false flag, Russia conducting false flag, Ukranine conducting false flag, false pretexts for war",
    "html script", ## garbage topic
    "gardening, agriculture", ## related to gardening during war, not really bioweapons docs
    "drones, aircrafts, military equipment",
    "nuclear attack, radioactivity",
    "civilians, atacks on civilians",
    "FSB, Russian security services, poisoning",
    "satellite data, European satellite data, sharing data",
    "online misinformation, fact checking",
    "Biden official statements, Bident visit to Ukraine",
    "statements from russian defense on biolabs, Russian investigation into biolabs",
    "allegations of Russian propaganda",
    "video transcript conjunctions", ## garbage topic
    "NATO summit, NATO statements",
    "covid19, pandemic, vaccines", ## unrelated topic
    "Russian expansionism, global order, collapse of global systems",
    "US domestic politics", ## some unrelated
    "television transcript", ## garbage topics
    "hunter biden involvement in bioweapons, hunter's laptop",
    "research in biolabs, research on pathogens",
    "families, children",
    "holidays", ## unrelated topic
    "fox news, tucker carlson, right wing media",
    "statements from ukranian leadership",
    "china's statements on ukraine",
    "energy, gas, inflation from invasion",
    "attacks on civilians",
    "legal, investigation"
  ),
  description = NA,
  stringsAsFactors = FALSE
)

## garbage topics
garbage <- c(20,15,4)

## "not in" operator, and a logical mask over the 30 topics that is
## FALSE for the garbage topics
`%ni%` <- Negate(`%in%`)
tokeep <- seq_len(30) %ni% garbage

################################
### Clustering by Topic #######
#################################

## bin topics, match on binned topics 
## not single matching, just find all documents in same bin 

## Candidate binning thresholds; `clusters` gets one row per document and
## one column per threshold
bins <- seq(.1, .9, .05)
clusters <- as.data.frame(matrix(NA, nrow = nrow(total_art_public),
                                 ncol = length(bins)))

## NOTE(review): the original loop contained `theta <- theta[tokeep, ]`,
## meant to drop the garbage topics. It indexed rows (documents) rather
## than columns (topics), and its result was overwritten straight away by
## thresholding `fit$theta` again, so it had no effect. The step is removed
## here and the garbage topics (see `garbage` / `tokeep`) therefore still
## take part in the cluster labels, exactly as in the original output.
## Excluding them for real (e.g. zeroing those columns of the 0/1 matrix)
## would change every downstream result, so confirm before doing so.
doc_topic <- fit$theta

for (i in seq_along(bins)) {
  ## 1/0 indicator: is the document's topic proportion at or above the
  ## current threshold?
  theta <- (doc_topic >= bins[i]) * 1

  ## Cluster label = indices of all topics above threshold for the
  ## document, joined with "_"; NA when no topic reaches the threshold
  clusters[, i] <- apply(theta, 1, function(x) {
    topic <- which(x == 1)
    if (length(topic) > 0) {
      paste0(topic, collapse = "_")
    } else {
      NA
    }
  })
}

## clusters dataframe
## for each document assigns name
## of topic cluster for each threshold (separated by columns)

## Number of documents with no cluster label (NA) at each threshold.
## vapply over the columns avoids coercing the data frame to a matrix
## and guarantees one integer per column.
count_unassigned <- vapply(clusters,
                           function(col) sum(is.na(col)),
                           integer(1))

## Name each column after its threshold (paste0 takes no `sep` argument)
colnames(clusters) <- paste0("cutoff_", bins)

## For each threshold, count the cluster labels held by exactly one
## document. Documents without a label (NA) form their own group, so that
## group is counted as a singleton if it has exactly one member.
counts <- lapply(seq_along(clusters), function(j) {
  tibble(cluster = clusters[[j]]) %>%
    dplyr::count(cluster) %>%
    summarize(total_singleton = sum(n == 1),
              bin = colnames(clusters)[j])
})
counts <- bind_rows(counts)
counts$unassigned <- count_unassigned

## Scatter of unassigned-document count against singleton-cluster count,
## one point per binning threshold
ggplot(counts, aes(x = unassigned, y = total_singleton)) +
  geom_point() +
  labs(x = "Count of Documents Unassigned to Any Topic Cluster",
       y = "Count of Singleton Topic Clusters") +
  theme_bw()
## figures/topic_cluster_distribution.pdf
##5x5

## Unassigned-document count as a function of the binning threshold
## (SI Figure A10). The numeric threshold is recovered from the
## "cutoff_<value>" column label.
counts$cutoff <- as.numeric(sub("cutoff_", "", counts$bin, fixed = TRUE))
ggplot(counts, aes(x = cutoff, y = unassigned)) +
  geom_point() +
  labs(x = "Binning Threshold",
       y = "Count of Documents Unassigned to Any Topic Cluster") +
  theme_bw()
## figures/topic_bin_threshold.pdf
## 5x5

## set cutoff at .2
## (uses the cluster labels from the "cutoff_0.2" column of `clusters`;
## documents with no label at this threshold get NA)
total_art_public$topic_cluster <- clusters$cutoff_0.2

#############################
### Create Precision Set ####
#############################

## create adjacency network
## square document-by-document matrix, initialised to 0 (no shared cluster)
adj_matrix <- matrix(0, nrow = nrow(total_art_public),
                     ncol = nrow(total_art_public))

## fill in adjacency network, where any 1 indicates
## in shared topic cluster
## note: I could have done this in a simpler way,
## but I preserved this correct, albeit circuitous
## method in order to preserve replication of downstream data
## Row i is filled for every document carrying the same cluster label as
## document i. Documents with an NA label are skipped, so their rows stay 0.
## Each labelled document also matches itself, which puts a 1 on the
## diagonal; those self loops are dropped further below.
for(i in 1:nrow(total_art_public)){
  if(!is.na(total_art_public$topic_cluster[i])){
    index <- which(total_art_public$topic_cluster == 
                     total_art_public$topic_cluster[i])
    adj_matrix[i, index] <- 1
  }
}

## create edgelist 
## The matrix is turned into an undirected igraph object and then into a
## two-column edgelist (ego, alter) stored as a data frame.
## NOTE(review): ego/alter are later used as row indices into
## total_art_public, which presumes the graph's vertex ids follow the
## matrix row order -- confirm if this construction is ever changed.
net <- graph_from_adjacency_matrix(adj_matrix, mode = "undirected",
                       weighted = TRUE)
net <- as_edgelist(net)
colnames(net) <- c("ego", "alter")
net <- as.data.frame(net)

## remove self loops
## don't need to remove duplicated ego, alter pairs
## b/c that's already done with functions above
net <- net[net$ego != net$alter, ]

## adding in variables
## ego / alter are row indices into total_art_public; look up the
## article id, summary text and publication date for each side of a pair
net$ego_id <- total_art_public$article_id[net$ego]
net$alter_id <- total_art_public$article_id[net$alter]
net$ego_article <- total_art_public$summary[net$ego]
net$alter_article <- total_art_public$summary[net$alter]
net$ego_date <- as.Date(total_art_public$final_date[net$ego])
net$alter_date <- as.Date(total_art_public$final_date[net$alter])

## remove outside 5 day window
## date_range is TRUE when the two publication dates are at most 5 days
## apart (in either direction). Plain Date arithmetic is used instead of
## lubridate's days(): lubridate is never loaded explicitly here and is
## only attached by library(tidyverse) from tidyverse 2.0.0 onwards.
## Pairs with a missing date give NA and are dropped by filter().
net$date_range <- 
  abs(as.numeric(net$ego_date - net$alter_date, units = "days")) <= 5
net <- net %>%
  dplyr::filter(date_range)

## Draw 100 document pairs (without replacement) for the precision check
set.seed(08540)

samp <- net[sample(1:nrow(net), 100, replace = FALSE), ]

## empty coding columns to be filled in by the research assistants
samp[c("same_subject", "same_claim", "refutation")] <- NA

## first 50 sampled pairs go to RA1, the remaining 50 to RA2
samp$RA <- rep(c("RA1", "RA2"), each = 50)

## creating master file 
#saveRDS(samp, "data/precision_stm_test_updated.rds")

## RA files (one coding sheet per research assistant)
#write.csv(samp[samp$RA == "RA1", c("ego",
#                                   "alter",
#                                   "ego_article",
#                                   "alter_article",
#                                   "same_subject",
#                                   "same_claim",
#                                   "refutation")], "data/precision_stm_test_ra1_updated_10_9_2024.csv")

#write.csv(samp[samp$RA == "RA2", c("ego",
#                                   "alter",
#                                   "ego_article",
#                                   "alter_article",
#                                   "same_subject",
#                                   "same_claim",
#                                   "refutation")], "data/precision_stm_test_ra2_updated_10_9_2024.csv")


## saving pairs
#saveRDS(net, "data/stm_pairs.rds")
